{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "def proc_direct(direct):\n",
    "    if 'younger' in direct and 'older' not in direct:\n",
    "        direct_pred = 0\n",
    "    elif 'older' in direct and 'younger' not in direct:\n",
    "        direct_pred = 2\n",
    "    elif 'same age' in direct:\n",
    "        direct_pred = 1\n",
    "    elif 'cannot decide' in direct:\n",
    "        direct_pred = 3\n",
    "    else:\n",
    "        direct_pred = -1\n",
    "    return direct_pred\n",
    "\n",
    "def proc_cot(cot):\n",
    "    if 'final answer:' not in cot:\n",
    "        # print(cot)\n",
    "        # print(\"*****************\")\n",
    "        cot_pred = -1\n",
    "    else:\n",
    "        cot = cot.split('final answer:')[-1].strip()\n",
    "        if 'cannot decide' in cot:\n",
    "            cot_pred = 3\n",
    "        elif 'younger' in cot and 'older' not in cot:\n",
    "            cot_pred = 0\n",
    "        elif 'older' in cot and 'younger' not in cot:\n",
    "            cot_pred = 2\n",
    "        elif 'same age' in cot:\n",
    "            cot_pred = 1\n",
    "        else:\n",
    "            cot_pred = -1\n",
    "    return cot_pred\n",
    "\n",
    "def evaluate(pred, gold, correct_inds=False):\n",
    "    inds = []\n",
    "    eval_dict = {\"invalid\": 0, \"correct\": 0, \"wrong\": 0, \"cannot decide\": 0}\n",
    "    for i in range(len(pred)):\n",
    "        p,g = pred[i], gold[i]\n",
    "        if p == -1:\n",
    "            eval_dict[\"invalid\"] += 1\n",
    "        elif p == 3:\n",
    "            eval_dict[\"cannot decide\"] += 1\n",
    "        else:\n",
    "            assert p in [0,1,2] and g in [0,1,2]\n",
    "            if p==g:\n",
    "                inds.append(i)\n",
    "                eval_dict['correct'] += 1\n",
    "            else:\n",
    "                eval_dict['wrong'] += 1\n",
    "    for key, val in eval_dict.items():\n",
    "        eval_dict[key] = round(val/len(pred), 3)\n",
    "    print(eval_dict)\n",
    "    if correct_inds:\n",
    "        return inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======no retrieval augmentation======\n",
      "model: gemini, direct QA performance: {'invalid': 0.0, 'correct': 0.287, 'wrong': 0.513, 'cannot decide': 0.2}\n",
      "model: gemini, CoT performance: {'invalid': 0.0, 'correct': 0.113, 'wrong': 0.18, 'cannot decide': 0.707}\n"
     ]
    }
   ],
   "source": [
    "# no retrieval augmentation\n",
    "print(\"======no retrieval augmentation======\")\n",
    "model = 'gemini'\n",
    "\n",
    "direct_pred_l = []\n",
    "cot_pred_l = []\n",
    "cot_whole = []\n",
    "gold_l = []\n",
    "\n",
    "for k in range(150):\n",
    "    with open(\"LLM/{}_directna_{}.txt\".format(model, k)) as f:\n",
    "        direct = f.read().strip()\n",
    "    with open(\"LLM/{}_cot_{}.txt\".format(model, k)) as f:\n",
    "        cot = f.read().strip().lower()\n",
    "    with open(\"LLM/answer_{}.txt\".format(k)) as f:\n",
    "        ans = f.read().strip()\n",
    "    direct_pred_l.append(proc_direct(direct))\n",
    "    cot_whole.append(cot)\n",
    "    cot_pred_l.append(proc_cot(cot))\n",
    "    gold_l.append(int(ans))\n",
    "\n",
    "print(\"model: {}, direct QA performance:\".format(model), end=\" \")\n",
    "evaluate(direct_pred_l, gold_l)\n",
    "print(\"model: {}, CoT performance:\".format(model), end=\" \")\n",
    "correct_inds = evaluate(cot_pred_l, gold_l, correct_inds=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======retrieval augmentation======\n",
      "model: gpt4turbo, direct QA performance: {'invalid': 0.0, 'correct': 0.333, 'wrong': 0.667, 'cannot decide': 0.0}\n",
      "model: gpt4turbo, CoT performance: {'invalid': 0.04, 'correct': 0.313, 'wrong': 0.52, 'cannot decide': 0.127}\n",
      "model: gemini, direct QA performance: {'invalid': 0.0, 'correct': 0.373, 'wrong': 0.593, 'cannot decide': 0.033}\n",
      "model: gemini, CoT performance: {'invalid': 0.0, 'correct': 0.12, 'wrong': 0.293, 'cannot decide': 0.587}\n"
     ]
    }
   ],
   "source": [
    "# with retrieval augmentation\n",
    "print(\"======retrieval augmentation======\")\n",
    "for model in ['gpt4turbo', 'gemini']:\n",
    "\n",
    "    direct_pred_l = []\n",
    "    cot_pred_l = []\n",
    "    cot_whole = []\n",
    "    gold_l = []\n",
    "\n",
    "    for k in range(150):\n",
    "        with open(\"LLM/{}_retrieval_directna_{}.txt\".format(model, k)) as f:\n",
    "            direct = f.read().strip()\n",
    "        with open(\"LLM/{}_retrieval_cot_{}.txt\".format(model, k)) as f:\n",
    "            cot = f.read().strip().lower()\n",
    "        with open(\"LLM/answer_{}.txt\".format(k)) as f:\n",
    "            ans = f.read().strip()\n",
    "        direct_pred_l.append(proc_direct(direct))\n",
    "        cot_whole.append(cot)\n",
    "        cot_pred_l.append(proc_cot(cot))\n",
    "        gold_l.append(int(ans))\n",
    "\n",
    "    print(\"model: {}, direct QA performance:\".format(model), end=\" \")\n",
    "    evaluate(direct_pred_l, gold_l)\n",
    "    print(\"model: {}, CoT performance:\".format(model), end=\" \")\n",
    "    correct_inds = evaluate(cot_pred_l, gold_l, correct_inds=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CLM",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
